import matplotlib.pyplot as plt
import numpy as np
import scienceplots

plt.style.use(['science', 'grid', 'ieee', 'scatter'])

# Estimated count that counts as a correct result. Points are classified by
# comparing each trial's estimate against this value.
# NOTE(review): presumably the true number of components/sources — confirm.
TRUE_COUNT = 2

# One entry per outcome category: (legend label, colour, marker, comparison).
# This single table drives both the plotted points and the legend proxies so
# the two can never drift apart. Order is preserved from the original script
# (success, under-estimate, over-estimate), which also fixes the draw order.
OUTCOMES = (
    ('Success', 'red', 'o', np.equal),
    ('Failure (Less)', 'blue', 'v', np.less),
    ('Failure (More)', 'green', '^', np.greater),
)


def _plot_outcomes(a, x, y, results, title):
    """Draw one panel: scatter (x, y) split into outcome categories.

    Parameters
    ----------
    a : matplotlib Axes to draw on.
    x, y : arrays of per-trial coordinates (same length as *results*).
    results : array of per-trial estimated counts, compared to TRUE_COUNT.
    title : panel title.
    """
    a.set_title(title, fontsize=15, weight='bold')
    for _, colour, marker, compare in OUTCOMES:
        mask = compare(results, TRUE_COUNT)
        a.scatter(x[mask], y[mask], c=colour, marker=marker)


results_music = np.load("results_music.npy")
results_aic = np.load("results_aic.npy")
results_bic = np.load("results_bic.npy")
separation = np.load("separations.npy")
samplesize = np.load("samplesize.npy")

print(results_music)

fig, ax = plt.subplots(1, 3, figsize=(8, 2.3))

# The x axis shows the base-10 log of the separation.
log_separation = np.log10(separation)

panels = (
    (results_music, "Proposed"),
    (results_aic, "AIC"),
    (results_bic, "BIC"),
)
for a, (results, title) in zip(ax, panels):
    _plot_outcomes(a, log_separation, samplesize, results, title)

# samplesize is plotted as-is under a log(n) label.
# NOTE(review): assumes samplesize.npy already stores log(n) — confirm.
ax[0].set_ylabel(r"$\log(n)$", fontsize=12)

# Empty scatters act as legend proxies so the legend shows every category
# even if a category has no points in the last panel.
for label, colour, marker, _ in OUTCOMES:
    ax[-1].scatter([], [], c=colour, marker=marker, label=label)
ax[-1].legend(loc='upper right', bbox_to_anchor=(1.65, 1))

for a in ax:
    a.set_ylim(2.5, 6)
    a.set_xlabel(r"$\log_{10}\Delta$", fontsize=12)

plt.tight_layout()
plt.savefig("transitions_log.pdf")